import cv2
import numpy as np
from reedsolo import RSCodec
from reedsolo import ReedSolomonError

# ==================================================
# REED SOLOMON
# ==================================================

# Codec with 4 parity symbols appended to each message; reedsolo can repair
# up to nsym // 2 = 2 unknown byte errors per codeword.
rsc = RSCodec(nsym=4)

# ==================================================
# CAMERA
# ==================================================

# Open camera 0 through the DirectShow backend.
cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)

# Fail loudly instead of letting the main loop exit silently on the first
# failed read.
if not cap.isOpened():
    raise SystemExit("ERROR: could not open camera 0 (DirectShow backend)")

cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1920)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080)

# The grid / outline pixel coordinates below appear to assume 1920x1080;
# drivers may silently choose another mode, so report what we actually got.
actual_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
actual_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
if (actual_w, actual_h) != (1920, 1080):
    print(
        f"WARNING: camera resolution is {actual_w}x{actual_h}, "
        f"expected 1920x1080 - grid positions may be off"
    )

print(
    "r=read | wasd=move | WASD=fast | g/h=zoom | G/H=fast | "
    "i/o=rotate | I/O=fast | q=quit"
)

# ==================================================
# GRID POINTS
# ==================================================

# Reference positions of the first column's 8 sample points (top to bottom).
base_points = [
    (668, row_y)
    for row_y in (387, 433, 477, 522, 567, 610, 655, 700)
]

# ==================================================
# 20 COLUMNS
# ==================================================

# Horizontal distance of every column from base_points: 8 columns on a
# 45 px pitch, then 3 on a 42 px pitch, then 9 more on a 45 px pitch.
offset = 8 * 45
column_dx = [step * 45 for step in range(8)]
column_dx += [offset + step * 42 for step in range(3)]
offset += 3 * 42
column_dx += [offset + step * 45 for step in range(9)]

# ==================================================
# FINE ADJUSTMENTS
# ==================================================

# Per-column nudges: column 8 moves 1 px left, columns 10 onward move
# 2 px right, except column 11 which moves 4 px right.
column_dx[8] -= 1
for col_idx in range(10, min(20, len(column_dx))):
    column_dx[col_idx] += 4 if col_idx == 11 else 2

# Each column is the base column shifted horizontally by its offset.
columns = [
    [(x + dx, y) for (x, y) in base_points]
    for dx in column_dx
]

# ==================================================
# THRESHOLD
# ==================================================

# Grayscale level strictly below which a sampled pixel reads as a 1 bit.
THRESHOLD = 100

# ==================================================
# CARD OUTLINE
# ==================================================

# Corners of the card outline in reference coordinates:
# top-left, top-right, bottom-right, bottom-left.
poly_pts = np.float32([
    (430, 317),
    (1555, 317),
    (1555, 771),
    (430, 771),
])

# ==================================================
# TRANSFORM CONTROLS
# ==================================================

# Keyboard-adjusted calibration applied to every reference point.
offset_x, offset_y = 0, 0  # translation in pixels
scale = 1.0                # zoom factor about the frame centre
rotation = 0.0             # rotation angle in radians

# ==================================================
# OVERLAY DATA
# ==================================================

# Results of the most recent 'r' read; the overlays draw from these.
last_matrix = None
last_status, last_color = "", (255, 255, 255)

decimal_values, rs_values = [], []
fixed_bytes, missing_bytes = [], []

# ==================================================
# MAIN LOOP
# ==================================================

# ==================================================
# CARD LAYOUT / KEY BINDINGS
# ==================================================

# The 20 columns are read as 4 blocks of 5 columns: four data bytes
# followed by one Reed-Solomon parity byte.
_BLOCKS = 4
_DATA_PER_BLOCK = 4
_COLS_PER_BLOCK = _DATA_PER_BLOCK + 1

# Calibration keys (character -> adjustment), see _apply_calibration_key.
_MOVE_KEYS = {
    "w": (0, -1),
    "s": (0, 1),
    "a": (-1, 0),
    "d": (1, 0),
    "W": (0, -5),
    "S": (0, 5),
    "A": (-5, 0),
    "D": (5, 0),
}
_SHRINK_KEYS = {"g": 0.995, "G": 0.97}  # scale *= factor
_GROW_KEYS = {"h": 0.995, "H": 0.97}    # scale /= factor
_ROTATE_KEYS = {"i": -0.002, "o": 0.002, "I": -0.01, "O": 0.01}

_FONT = cv2.FONT_HERSHEY_SIMPLEX


# ==================================================
# TRANSFORM
# ==================================================

def transform(x, y):
    """Map a reference-card point into the current camera frame.

    Scales by ``scale`` and rotates by ``rotation`` (radians) about the
    frame centre ``(cx, cy)``, then translates by ``(offset_x, offset_y)``.
    All of these are module-level globals read at call time.

    Returns:
        Integer pixel coordinates ``(x, y)`` (truncated with ``int``).
    """
    cos_r = np.cos(rotation)
    sin_r = np.sin(rotation)

    x0 = (x - cx) * scale
    y0 = (y - cy) * scale

    xr = x0 * cos_r - y0 * sin_r
    yr = x0 * sin_r + y0 * cos_r

    return int(cx + xr + offset_x), int(cy + yr + offset_y)


# ==================================================
# READ CARD
# ==================================================

def _data_column(index):
    """Return the card column that holds data byte ``index``."""
    block, pos = divmod(index, _DATA_PER_BLOCK)
    return block * _COLS_PER_BLOCK + pos


def _sample_bit(image, x, y):
    """Return 1 if the BGR pixel at (x, y) is darker than THRESHOLD, else 0.

    Points outside the image read as 0.
    """
    height, width = image.shape[:2]
    if not (0 <= x < width and 0 <= y < height):
        return 0

    b, g, r_pix = image[y, x]
    gray = int(0.299 * r_pix + 0.587 * g + 0.114 * b)
    return 1 if gray < THRESHOLD else 0


def read_card():
    """Sample every card bit in the current ``frame`` and RS-decode it.

    Returns:
        ``(rs_ok, matrix, corrected_values, rs_bytes, fixed, missing)``:
        ``matrix`` is one list of bits per card column (bit i = row i, LSB
        first); ``corrected_values`` holds the decoded data bytes (empty on
        failure); ``rs_bytes`` are the parity bytes as read; ``fixed`` lists
        ``(data_index, read_value, corrected_value)``; ``missing`` names
        corrected data bytes whose column sampled as all-identical bits.
    """
    matrix = [
        [_sample_bit(frame, *transform(x, y)) for (x, y) in col]
        for col in columns
    ]

    all_values = [
        sum(bit << row for row, bit in enumerate(bits))
        for bits in matrix
    ]

    # Rebuild the RS packet: the codec expects all data bytes first,
    # followed by all parity bytes.
    data_values = []
    rs_bytes_local = []
    for block in range(_BLOCKS):
        start = block * _COLS_PER_BLOCK
        data_values.extend(all_values[start:start + _DATA_PER_BLOCK])
        rs_bytes_local.append(all_values[start + _DATA_PER_BLOCK])

    fixed = []
    missing = []

    try:
        decoded = rsc.decode(bytes(data_values + rs_bytes_local))
    except ReedSolomonError as e:
        print("RS ERROR:", e)
        return False, matrix, [], rs_bytes_local, fixed, missing

    corrected_values = list(decoded[0])

    for i, (original, corrected) in enumerate(
        zip(data_values, corrected_values)
    ):
        if original == corrected:
            continue

        fixed.append((i, original, corrected))

        # A corrected byte whose column sampled as all-identical bits may
        # have been unreadable. Look up the column that actually holds
        # data byte i -- the RS columns are interleaved with the data.
        bits = matrix[_data_column(i)]
        if all(bit == bits[0] for bit in bits):
            missing.append(f"D{i}")

    return True, matrix, corrected_values, rs_bytes_local, fixed, missing


# ==================================================
# DRAWING
# ==================================================

def _put(display, text, org, font_scale, color, thickness):
    """Draw ``text`` on ``display`` in the overlay font."""
    cv2.putText(display, text, org, _FONT, font_scale, color, thickness)


def _draw_decode_panel(display):
    """Draw the left panel: decoded data, RS bytes, fixed and missing bytes.

    Reads the module-level results stored by the last 'r' press.
    """
    x = 20
    top = 60
    step = 22

    rs_y = top + 35 + len(decimal_values) * step + 20
    fix_y = rs_y + 120
    miss_y = fix_y + 140

    # Extend the backing box if the MISSING list would run past its
    # default bottom edge.
    bottom = 720
    if missing_bytes:
        bottom = max(bottom, miss_y + 30 + len(missing_bytes) * step)

    cv2.rectangle(display, (10, 20), (470, bottom), (0, 0, 0), -1)

    _put(display, "Decoded bytes:", (x, top), 0.7, (0, 255, 0), 2)

    for i, val in enumerate(decimal_values):
        _put(display, f"D{i}: {val}", (x, top + 35 + i * step),
             0.6, (255, 255, 255), 1)

    for i, rs in enumerate(rs_values):
        _put(display, f"RS{i}: {rs}", (x, rs_y + i * step),
             0.6, (0, 255, 255), 2)

    if fixed_bytes:
        _put(display, "FIXED:", (x, fix_y), 0.7, (0, 255, 255), 2)
        for i, (idx, old, new) in enumerate(fixed_bytes):
            _put(display, f"D{idx}: {old}->{new}",
                 (x, fix_y + 30 + i * step), 0.55, (255, 255, 255), 1)

    if missing_bytes:
        _put(display, "MISSING:", (x, miss_y), 0.7, (0, 0, 255), 2)
        for i, name in enumerate(missing_bytes):
            _put(display, name, (x, miss_y + 30 + i * step),
                 0.55, (255, 255, 255), 1)


def _draw_card_guides(display):
    """Draw the card outline, orientation circle and every bit sample point."""
    # cv2.polylines requires int32 points; the platform default int may be
    # int64, which OpenCV rejects.
    poly_scaled = np.array(
        [transform(x, y) for (x, y) in poly_pts],
        dtype=np.int32,
    )
    cv2.polylines(
        display, [poly_scaled.reshape((-1, 1, 2))], True, (0, 255, 0), 2
    )

    cv2.circle(display, transform(505, 402), 20, (0, 255, 0), 2)

    for col in columns:
        for (x, y) in col:
            cv2.circle(display, transform(x, y), 5, (0, 0, 255), -1)


def _draw_bit_matrix(display):
    """Draw the last sampled bit matrix and RS status in the top-right corner."""
    cell = 18
    rows = 8  # one text row per sampled bit
    cols = len(last_matrix)

    left = display.shape[1] - cols * cell - 20
    top = 20

    cv2.rectangle(
        display,
        (left - 10, top - 10),
        (left + cols * cell + 10, top + rows * cell + 35),
        (0, 0, 0),
        -1,
    )

    for c, bits in enumerate(last_matrix):
        for r in range(rows):
            val = bits[r]
            color = (0, 255, 0) if val else (100, 100, 100)
            _put(display, str(val), (left + c * cell, top + r * cell + 14),
                 0.5, color, 1)

    _put(display, last_status, (left, top + rows * cell + 20),
         0.7, last_color, 2)


# ==================================================
# KEYBOARD ACTIONS
# ==================================================

def _apply_calibration_key(ch):
    """Adjust the global offset/scale/rotation for a calibration key."""
    global offset_x, offset_y, scale, rotation

    if ch in _MOVE_KEYS:
        dx, dy = _MOVE_KEYS[ch]
        offset_x += dx
        offset_y += dy
    elif ch in _SHRINK_KEYS:
        scale *= _SHRINK_KEYS[ch]
    elif ch in _GROW_KEYS:
        scale /= _GROW_KEYS[ch]
    elif ch in _ROTATE_KEYS:
        rotation += _ROTATE_KEYS[ch]


def _read_and_report():
    """Read the card, store the results for the overlays and print them."""
    global last_matrix, decimal_values, rs_values, fixed_bytes, missing_bytes
    global last_status, last_color

    ok, matrix, values, rs_found, fixed, missing = read_card()

    last_matrix = matrix
    decimal_values = values
    rs_values = rs_found
    fixed_bytes = fixed
    missing_bytes = missing

    print("")

    if not ok:
        last_status = "RS FAILED"
        last_color = (0, 0, 255)
        print("RS FAILED")
        return

    if fixed or missing:
        last_status = f"RS FIXED {len(fixed)} BYTES"
    else:
        last_status = "RS OK"
    last_color = (0, 255, 0)

    print("FOUND VALID CARD")
    print("")
    print("DATA:")
    print(decimal_values)
    print("")
    print("RS:")
    print(rs_values)

    if fixed:
        print("")
        print("FIXED BYTES:")
        for idx, old, new in fixed:
            print(f"D{idx}: {old} -> {new}")

    if missing:
        print("")
        print("MISSING / DAMAGED:")
        print(missing)


# ==================================================
# MAIN LOOP
# ==================================================

try:
    while True:

        ret, frame = cap.read()
        if not ret:
            break

        h, w = frame.shape[:2]
        cx, cy = w // 2, h // 2

        display = frame.copy()

        # Drawn before the guides so the outline and points are painted
        # over the panel.
        if last_matrix is not None:
            _draw_decode_panel(display)

        _draw_card_guides(display)

        # waitKey returns -1 when no key is pressed; & 0xFF maps it to 255,
        # which matches no binding.
        ch = chr(cv2.waitKey(1) & 0xFF)

        if ch == "q":
            break
        if ch == "r":
            _read_and_report()
        else:
            _apply_calibration_key(ch)

        if last_matrix is not None:
            _draw_bit_matrix(display)

        cv2.imshow("Camera", display)
finally:
    # Release the camera and close windows even if the loop raised.
    cap.release()
    cv2.destroyAllWindows()